{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T10:39:47.417718Z",
     "start_time": "2020-03-23T10:39:46.736236Z"
    }
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import sys\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "sys.path.insert(0, '..')\n",
    "from iharm.utils.exp import load_config_file\n",
    "\n",
    "cfg = load_config_file('../config.yml', return_edict=True)\n",
    "device = torch.device('cuda:1')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Init dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T10:39:47.775697Z",
     "start_time": "2020-03-23T10:39:47.420359Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded HAdobe5k dataset with 2160 images\n"
     ]
    }
   ],
   "source": [
    "from iharm.data.hdataset import HDataset\n",
    "from iharm.data.transforms import HCompose, LongestMaxSizeIfLarger\n",
    "\n",
    "# Possible choices: HFlickr, HDay2Night, HCOCO, HAdobe5k\n",
    "dataset_name = 'HAdobe5k'\n",
    "dataset = HDataset(\n",
    "    dataset_path=cfg.get(f'{dataset_name.upper()}_PATH'),\n",
    "    split='test',\n",
    "    augmentator=HCompose([LongestMaxSizeIfLarger(1024)])\n",
    ")\n",
    "\n",
    "print(f'Loaded {dataset_name} dataset with {len(dataset)} images')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Init model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T10:39:52.682759Z",
     "start_time": "2020-03-23T10:39:49.223122Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Load checkpoint from path: /hdd0/harmonization_exps/models/hrnet_dih_001_adobe5k_last.pth\n"
     ]
    }
   ],
   "source": [
    "from iharm.inference.utils import load_model, find_checkpoint\n",
    "from iharm.inference.predictor import Predictor\n",
    "\n",
    "checkpoint_path = find_checkpoint(cfg.MODELS_PATH, 'hrnet_dih_001_adobe5k_last')\n",
    "\n",
    "net = load_model('hrnet_dih', checkpoint_path, verbose=True)\n",
    "predictor = Predictor(net, device, with_flip=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T10:43:06.408079Z",
     "start_time": "2020-03-23T10:39:52.685202Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Testing on HAdobe5k: 100%|██████████| 2160/2160 [03:13<00:00, 11.15it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "                    |    N     |   MSE    |   fMSE   |   PSNR   | AvgTime, ms  |\n",
      "--------------------------------------------------------------------------------\n",
      "HAdobe5k            |   2160   |  36.06   |  192.47  |  36.11   |     52.0     |\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from iharm.inference.metrics import MetricsHub, MSE, fMSE, PSNR, N, AvgPredictTime\n",
    "from iharm.inference.evaluation import evaluate_dataset\n",
    "\n",
    "dataset_metrics = MetricsHub([N(), MSE(), fMSE(), PSNR(), AvgPredictTime()],\n",
    "                             name=dataset_name)\n",
    "\n",
    "evaluate_dataset(dataset, predictor, dataset_metrics)\n",
    "\n",
    "print(dataset_metrics.get_table_header())\n",
    "print(dataset_metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Single sample eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T10:23:49.279021Z",
     "start_time": "2020-03-23T10:23:47.951249Z"
    },
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "sample_id = 12\n",
    "\n",
    "sample = dataset.get_sample(sample_id)\n",
    "sample = dataset.augment_sample(sample)\n",
    "\n",
    "figsize=(10, 8)\n",
    "print('Composite image:')\n",
    "plt.figure(figsize=figsize)\n",
    "plt.imshow(sample['image'])\n",
    "plt.show()\n",
    "\n",
    "print('Composite object mask:')\n",
    "plt.figure(figsize=figsize)\n",
    "plt.imshow(sample['object_mask'])\n",
    "plt.show()\n",
    "\n",
    "pred = predictor.predict(sample['image'], sample['object_mask']).astype(np.uint8)\n",
    "\n",
    "print('Predicted harmonized image:')\n",
    "plt.figure(figsize=figsize)\n",
    "plt.imshow(pred)\n",
    "plt.show()\n",
    "\n",
    "print('Original image:')\n",
    "plt.figure(figsize=figsize)\n",
    "plt.imshow(sample['target_image'])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Make visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-23T10:43:17.370225Z",
     "start_time": "2020-03-23T10:43:06.411308Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "595fa698e18040228a76498cbf038786",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import cv2\n",
    "from pathlib import Path\n",
    "from tqdm.notebook import trange\n",
    "\n",
    "output_vis_path = Path('/hdd0/harmonization_exps/visualization/hrnet_dih_001_adobe5k_last/adobe5k')\n",
    "output_vis_path.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "for sample_id in trange(100):\n",
    "    sample = dataset.get_sample(sample_id)\n",
    "    sample = dataset.augment_sample(sample)\n",
    "    \n",
    "    pred = predictor.predict(sample['image'], sample['object_mask']).astype(np.uint8)\n",
    "    \n",
    "    splitter = np.full((pred.shape[0], 25, 3), 255, dtype=np.uint8)\n",
    "    \n",
    "    vis_sample = np.concatenate((sample['image'], splitter, pred,\n",
    "                                 splitter, sample['target_image']), axis=1)\n",
    "    \n",
    "    vis_sample = cv2.cvtColor(vis_sample, cv2.COLOR_RGB2BGR)\n",
    "    cv2.imwrite((output_vis_path / dataset.dataset_samples[sample_id]).as_posix(),\n",
    "                vis_sample, [cv2.IMWRITE_JPEG_QUALITY, 85])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "294px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}